package coordinator

import (
	"context"
	"fmt"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"github.com/maxpert/marmot/hlc"
	"github.com/maxpert/marmot/protocol"
)

// TestMultiStatementTransaction verifies BEGIN/COMMIT accumulates statements
// and executes them as a single distributed transaction
func TestMultiStatementTransaction(t *testing.T) {
	// Setup: 3-node cluster
	nodes := []uint64{1, 2, 3}
	nodeProvider := newMockNodeProvider(nodes)
	replicator := newMockReplicator()
	clock := hlc.NewClock(1)

	coordinator := NewWriteCoordinator(1, nodeProvider, replicator, replicator, 5*time.Second, clock)

	// Simulate BEGIN - start transaction
	txnState := &TransactionState{
		TxnID:      12345,
		StartTS:    hlc.Timestamp{WallTime: 1000, Logical: 0},
		Statements: []protocol.Statement{},
		Database:   "testdb",
	}

	// Add multiple statements
	stmt1 := protocol.Statement{SQL: "INSERT INTO users VALUES (1, 'Alice')", Type: protocol.StatementInsert}
	stmt2 := protocol.Statement{SQL: "UPDATE users SET name='Bob' WHERE id=1", Type: protocol.StatementUpdate}
	txnState.Statements = append(txnState.Statements, stmt1, stmt2)

	// Simulate COMMIT - execute all statements via 2PC
	txn := &Transaction{
		ID:               txnState.TxnID,
		NodeID:           1,
		Statements:       txnState.Statements,
		StartTS:          txnState.StartTS,
		WriteConsistency: protocol.ConsistencyQuorum,
		Database:         txnState.Database,
	}

	// Execute transaction
	ctx := context.Background()
	err := coordinator.WriteTransaction(ctx, txn)
	if err != nil {
		t.Fatalf("Expected transaction to succeed, got error: %v", err)
	}

	// Allow pending goroutines to complete (coordinator uses early quorum exit)
	time.Sleep(50 * time.Millisecond)

	// Verify: Both PREPARE and COMMIT phases executed
	calls := replicator.GetCalls()

	// With quorum optimization:
	// - PREPARE sent to all 3 nodes
	// - Coordinator waits for quorum (2) and exits early
	// - COMMIT sent only to nodes that ACKed PREPARE (2 nodes)
	// - Total: 3 PREPARE + 2 COMMIT = 5 calls (minimum for quorum)
	//
	// In rare cases where all 3 PREPARE responses arrive before quorum check,
	// we might get 6 calls (3 PREPARE + 3 COMMIT). Both are valid.
	if len(calls) < 5 {
		t.Errorf("Expected at least 5 replication calls, got %d", len(calls))
	}

	// Verify PREPARE phase had both statements
	prepareCallFound := false
	for _, call := range calls {
		if call.Request.Phase == PhasePrep {
			if len(call.Request.Statements) != 2 {
				t.Errorf("PREPARE phase should have 2 statements, got %d", len(call.Request.Statements))
			}
			prepareCallFound = true
			break
		}
	}

	if !prepareCallFound {
		t.Errorf("No PREPARE phase call found")
	}
}

// TestLockWaitingOnConflict verifies that when a transaction encounters
// a conflict, it waits for the lock to be released instead of failing immediately
func TestLockWaitingOnConflict(t *testing.T) {
	nodes := []uint64{1, 2, 3}
	nodeProvider := newMockNodeProvider(nodes)
	replicator := newMockReplicator()
	clock := hlc.NewClock(1)

	coordinator := NewWriteCoordinator(1, nodeProvider, replicator, replicator, 5*time.Second, clock)

	// Transaction 1: Locks row with id=1
	txn1 := &Transaction{
		ID:               100,
		NodeID:           1,
		Statements:       []protocol.Statement{{SQL: "UPDATE users SET name='Alice' WHERE id=1", Type: protocol.StatementUpdate}},
		StartTS:          hlc.Timestamp{WallTime: 1000, Logical: 0},
		WriteConsistency: protocol.ConsistencyQuorum,
		Database:         "testdb",
	}

	// Start txn1 in background - it will hold the lock
	var txn1Err error
	txn1Done := make(chan struct{})
	go func() {
		txn1Err = coordinator.WriteTransaction(context.Background(), txn1)
		close(txn1Done)
	}()

	// Give txn1 time to acquire lock
	time.Sleep(100 * time.Millisecond)

	// Transaction 2: Tries to update same row - should WAIT, not fail immediately
	txn2 := &Transaction{
		ID:               200,
		NodeID:           1,
		Statements:       []protocol.Statement{{SQL: "UPDATE users SET name='Bob' WHERE id=1", Type: protocol.StatementUpdate}},
		StartTS:          hlc.Timestamp{WallTime: 2000, Logical: 0},
		WriteConsistency: protocol.ConsistencyQuorum,
		Database:         "testdb",
	}

	// Configure replicator to simulate conflict on node 2
	replicator.SetConflict(2, txn1.ID, "row locked by txn 100")

	// Start txn2 - it should BLOCK waiting for txn1
	var txn2Err error
	txn2Start := time.Now()
	txn2Done := make(chan struct{})
	go func() {
		txn2Err = coordinator.WriteTransaction(context.Background(), txn2)
		close(txn2Done)
	}()

	// Give txn2 time to encounter conflict and start waiting
	time.Sleep(200 * time.Millisecond)

	// Release txn1's lock by clearing conflict
	replicator.ClearConflict(2, txn1.ID)

	// Wait for both transactions to complete
	<-txn1Done
	<-txn2Done
	txn2Duration := time.Since(txn2Start)

	// Verify: txn1 succeeded
	if txn1Err != nil {
		t.Errorf("txn1 should succeed, got error: %v", txn1Err)
	}

	// Verify: txn2 succeeded after waiting
	if txn2Err != nil {
		t.Errorf("txn2 should succeed after waiting, got error: %v", txn2Err)
	}

	// Verify: txn2 waited at least 200ms
	if txn2Duration < 200*time.Millisecond {
		t.Errorf("txn2 should have waited, but completed in %v", txn2Duration)
	}
}

// TestConflictReturnsDeadlockError verifies that when a conflict is detected,
// the transaction fails immediately with MySQL error 1213 (deadlock) for client retry
func TestConflictReturnsDeadlockError(t *testing.T) {
	nodes := []uint64{1, 2, 3}
	nodeProvider := newMockNodeProvider(nodes)
	replicator := newMockReplicator()
	clock := hlc.NewClock(1)

	coordinator := NewWriteCoordinator(1, nodeProvider, replicator, replicator, 500*time.Millisecond, clock)

	// Transaction that will conflict
	txn := &Transaction{
		ID:               200,
		NodeID:           1,
		Statements:       []protocol.Statement{{SQL: "UPDATE users SET name='Bob' WHERE id=1", Type: protocol.StatementUpdate}},
		StartTS:          hlc.Timestamp{WallTime: 2000, Logical: 0},
		WriteConsistency: protocol.ConsistencyQuorum,
		Database:         "testdb",
	}

	// Simulate conflict on nodes 2 and 3
	replicator.SetPersistentNodeConflict(2, "row locked by txn 100")
	replicator.SetPersistentNodeConflict(3, "row locked by txn 100")

	// Execute - should fail immediately with deadlock error
	start := time.Now()
	err := coordinator.WriteTransaction(context.Background(), txn)
	duration := time.Since(start)

	// Verify: Failed with error
	if err == nil {
		t.Fatal("Expected error, got success")
	}

	// Verify: Error is MySQL error 1213 (deadlock - client should retry)
	mysqlErr, ok := err.(*MySQLError)
	if !ok {
		t.Fatalf("Expected MySQLError, got %T: %v", err, err)
	}

	if mysqlErr.Code != 1213 {
		t.Errorf("Expected error code 1213 (deadlock), got %d", mysqlErr.Code)
	}

	if mysqlErr.SQLState != "40001" {
		t.Errorf("Expected SQLSTATE 40001, got %s", mysqlErr.SQLState)
	}

	// Verify: Failed quickly (no internal retry)
	if duration > 100*time.Millisecond {
		t.Errorf("Expected immediate failure, but waited %v", duration)
	}
}

// TestConcurrentTransactionsSerializeOnConflict verifies that multiple
// concurrent transactions on the same row execute serially (one waits for the other)
func TestConcurrentTransactionsSerializeOnConflict(t *testing.T) {
	nodes := []uint64{1, 2, 3}
	nodeProvider := newMockNodeProvider(nodes)
	replicator := newMockReplicator()
	clock := hlc.NewClock(1)

	coordinator := NewWriteCoordinator(1, nodeProvider, replicator, replicator, 5*time.Second, clock)

	// Track execution order
	var executionOrder []uint64
	var mu sync.Mutex

	// Launch 3 concurrent transactions updating same row
	var wg sync.WaitGroup
	for i := uint64(1); i <= 3; i++ {
		wg.Add(1)
		go func(txnID uint64) {
			defer wg.Done()

			txn := &Transaction{
				ID:               txnID,
				NodeID:           1,
				Statements:       []protocol.Statement{{SQL: fmt.Sprintf("UPDATE users SET name='User%d' WHERE id=1", txnID), Type: protocol.StatementUpdate}},
				StartTS:          hlc.Timestamp{WallTime: int64(1000 * txnID), Logical: 0},
				WriteConsistency: protocol.ConsistencyQuorum,
				Database:         "testdb",
			}

			// Simulate that each transaction takes 100ms to complete
			replicator.SetDelay(txnID, 100*time.Millisecond)

			err := coordinator.WriteTransaction(context.Background(), txn)
			if err != nil {
				t.Errorf("Transaction %d failed: %v", txnID, err)
			}

			mu.Lock()
			executionOrder = append(executionOrder, txnID)
			mu.Unlock()
		}(i)

		// Stagger start times slightly
		time.Sleep(10 * time.Millisecond)
	}

	wg.Wait()

	// Verify: All 3 transactions completed
	if len(executionOrder) != 3 {
		t.Errorf("Expected 3 transactions to complete, got %d", len(executionOrder))
	}

	// Verify: They executed serially (no two overlapped)
	// This is verified implicitly by the fact they all succeeded without conflict errors
}

// TestDeadlockDetection verifies that circular lock dependencies are detected
// and one transaction is aborted with MySQL error 1213
func TestDeadlockDetection(t *testing.T) {
	t.Skip("Deadlock detection to be implemented in phase 2")

	// This test will be implemented when we add deadlock detection
	// For now, we just do lock waiting with timeout
}

// TestNoWaitOnDifferentRows verifies that transactions on different rows
// proceed in parallel without waiting
func TestNoWaitOnDifferentRows(t *testing.T) {
	nodes := []uint64{1, 2, 3}
	nodeProvider := newMockNodeProvider(nodes)
	replicator := newMockReplicator()
	clock := hlc.NewClock(1)

	coordinator := NewWriteCoordinator(1, nodeProvider, replicator, replicator, 5*time.Second, clock)

	// Track concurrent execution
	var activeCount atomic.Int32
	var maxConcurrent atomic.Int32

	// Launch 5 concurrent transactions on different rows
	var wg sync.WaitGroup
	for i := uint64(1); i <= 5; i++ {
		wg.Add(1)
		go func(rowID uint64) {
			defer wg.Done()

			// Increment active count
			current := activeCount.Add(1)

			// Track max concurrency
			for {
				max := maxConcurrent.Load()
				if current <= max || maxConcurrent.CompareAndSwap(max, current) {
					break
				}
			}

			txn := &Transaction{
				ID:               rowID * 100,
				NodeID:           1,
				Statements:       []protocol.Statement{{SQL: fmt.Sprintf("UPDATE users SET name='User' WHERE id=%d", rowID), Type: protocol.StatementUpdate}},
				StartTS:          hlc.Timestamp{WallTime: int64(1000 * rowID), Logical: 0},
				WriteConsistency: protocol.ConsistencyQuorum,
				Database:         "testdb",
			}

			// Simulate each transaction taking 200ms
			replicator.SetDelay(txn.ID, 200*time.Millisecond)

			err := coordinator.WriteTransaction(context.Background(), txn)
			if err != nil {
				t.Errorf("Transaction %d failed: %v", rowID, err)
			}

			activeCount.Add(-1)
		}(i)

		// Start them all at roughly the same time
		time.Sleep(10 * time.Millisecond)
	}

	wg.Wait()

	// Verify: Multiple transactions ran concurrently (should see at least 3 concurrent)
	if maxConcurrent.Load() < 3 {
		t.Errorf("Expected at least 3 concurrent transactions, got max %d", maxConcurrent.Load())
	}
}

// TransactionState tracks an open multi-statement transaction
type TransactionState struct {
	TxnID      uint64
	StartTS    hlc.Timestamp
	Statements []protocol.Statement
	Database   string
}
